# --------------------------------------------------------
# Fast R-CNN
# Copyright (c) 2015 Microsoft
# Licensed under The MIT License [see LICENSE for details]
# Written by Ross Girshick
# --------------------------------------------------------

"""Test a Fast R-CNN network on an imdb (image database)."""

from fast_rcnn.config import cfg, get_output_dir
from fast_rcnn.bbox_transform import clip_boxes, bbox_transform_inv
from utils.timer import Timer
import numpy as np
import cv2
import caffe
from fast_rcnn.nms_wrapper import nms, soft_nms
import cPickle
from utils.blob import im_list_to_blob
import os

def _get_image_blob(im):
    """Converts an image into a network input.

    Arguments:
        im (ndarray): a color image in BGR order

    Returns:
        blob (ndarray): a data blob holding an image pyramid
        im_scale_factors (list): list of image scales (relative to im) used
            in the image pyramid
    """
    im_orig = im.astype(np.float32, copy=True)
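    # Zero-center the input by subtracting the dataset mean pixel (BGR order).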
    im_orig -= cfg.PIXEL_MEANS

    im_shape = im_orig.shape
    im_size_min = np.min(im_shape[0:2])
    im_size_max = np.max(im_shape[0:2])

    processed_ims = []
    im_scale_factors = []

    for target_size in cfg.TEST.SCALES:
        im_scale = float(target_size) / float(im_size_min)
        # Prevent the biggest axis from being more than MAX_SIZE
        if np.round(im_scale * im_size_max) > cfg.TEST.MAX_SIZE:
            im_scale = float(cfg.TEST.MAX_SIZE) / float(im_size_max)
        im = cv2.resize(im_orig, None, None, fx=im_scale, fy=im_scale,
                        interpolation=cv2.INTER_LINEAR)
        im_scale_factors.append(im_scale)
        processed_ims.append(im)

    # Create a blob to hold the input images
    blob = im_list_to_blob(processed_ims)

    return blob, np.array(im_scale_factors)

def _get_rois_blob(im_rois, im_scale_factors):
    """Converts RoIs into network inputs.

    Arguments:
        im_rois (ndarray): R x 4 matrix of RoIs in original image coordinates
        im_scale_factors (list): scale factors as returned by _get_image_blob

    Returns:
        blob (ndarray): R x 5 matrix of RoIs in the image pyramid
    """
    rois, levels = _project_im_rois(im_rois, im_scale_factors)
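    # Each returned row is (level, x1, y1, x2, y2): the pyramid level index
    # followed by the RoI coordinates scaled to that level.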
    rois_blob = np.hstack((levels, rois))
    return rois_blob.astype(np.float32, copy=False)

def _project_im_rois(im_rois, scales):
    """Project image RoIs into the image pyramid built by _get_image_blob.

    Arguments:
        im_rois (ndarray): R x 4 matrix of RoIs in original image coordinates
        scales (list): scale factors as returned by _get_image_blob

    Returns:
        rois (ndarray): R x 4 matrix of projected RoI coordinates
        levels (list): image pyramid levels used by each projected RoI
    """
    im_rois = im_rois.astype(np.float64, copy=False)

    if len(scales) > 1:
        widths = im_rois[:, 2] - im_rois[:, 0] + 1
        heights = im_rois[:, 3] - im_rois[:, 1] + 1

        areas = widths * heights
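        # Assign each RoI to the pyramid level whose scaled area is closest
        # to 224 x 224, the canonical network input area.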
        scaled_areas = areas[:, np.newaxis] * (scales[np.newaxis, :] ** 2)
        diff_areas = np.abs(scaled_areas - 224 * 224)
        levels = diff_areas.argmin(axis=1)[:, np.newaxis]
    else:
        levels = np.zeros((im_rois.shape[0], 1), dtype=int)

    rois = im_rois * scales[levels]

    return rois, levels

def _get_blobs(im, rois):
    """Convert an image and RoIs within that image into network inputs."""
    blobs = {'data' : None, 'rois' : None}
    blobs['data'], im_scale_factors = _get_image_blob(im)
    if not cfg.TEST.HAS_RPN:
        blobs['rois'] = _get_rois_blob(rois, im_scale_factors)
    return blobs, im_scale_factors

def im_detect(net, im, boxes=None, force_boxes=False):
    """Detect object classes in an image given object proposals.

    Arguments:
        net (caffe.Net): Fast R-CNN network to use
        im (ndarray): color image to test (in BGR order)
        boxes (ndarray): R x 4 array of object proposals or None (for RPN)
        force_boxes (bool): if True, use the given boxes as RoIs even when
            the network has its own RPN

    Returns:
        scores (ndarray): R x K array of object class scores (K includes
            background as object category 0)
        boxes (ndarray): R x (4*K) array of predicted bounding boxes
        attr_scores (ndarray): R x M array of attribute class scores, or
            None if the network has no attribute head
        rel_scores (ndarray): R^2 x N array of relation class scores, or
            None if the network has no relation head
    """
    blobs, im_scales = _get_blobs(im, boxes)
    if force_boxes:
        blobs['rois'] = _get_rois_blob(boxes, im_scales)

    # When mapping from image ROIs to feature map ROIs, there's some aliasing
    # (some distinct image ROIs get mapped to the same feature ROI).
    # Here, we identify duplicate feature ROIs, so we only compute features
    # on the unique subset.
    if cfg.DEDUP_BOXES > 0 and not cfg.TEST.HAS_RPN:
        v = np.array([1, 1e3, 1e6, 1e9, 1e12])
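        # Hash each scaled RoI by treating its quantized (level, x1, y1, x2, y2)
        # row as digits of a base-1000 number; equal hashes flag duplicates.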
        hashes = np.round(blobs['rois'] * cfg.DEDUP_BOXES).dot(v)
        _, index, inv_index = np.unique(hashes, return_index=True,
                                        return_inverse=True)
        blobs['rois'] = blobs['rois'][index, :]
        boxes = boxes[index, :]

    im_blob = blobs['data']
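    # im_info is (pyramid image height, width, scale); RPN-style networks use
    # it to clip proposals to the image boundary.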
    blobs['im_info'] = np.array(
        [[im_blob.shape[2], im_blob.shape[3], im_scales[0]]],
        dtype=np.float32)

    # reshape network inputs
    net.blobs['data'].reshape(*(blobs['data'].shape))
    if 'im_info' in net.blobs:
        net.blobs['im_info'].reshape(*(blobs['im_info'].shape))
    if force_boxes or not cfg.TEST.HAS_RPN:
        net.blobs['rois'].reshape(*(blobs['rois'].shape))

    # do forward
    forward_kwargs = {'data': blobs['data'].astype(np.float32, copy=False)}
    if 'im_info' in net.blobs:
        forward_kwargs['im_info'] = blobs['im_info'].astype(np.float32, copy=False)
    if force_boxes or not cfg.TEST.HAS_RPN:
        forward_kwargs['rois'] = blobs['rois'].astype(np.float32, copy=False)
    blobs_out = net.forward(**forward_kwargs)

    if cfg.TEST.HAS_RPN and not force_boxes:
        assert len(im_scales) == 1, "Only single-image batch implemented"
        rois = net.blobs['rois'].data.copy()
        # unscale back to raw image space
        boxes = rois[:, 1:5] / im_scales[0]

    if cfg.TEST.SVM:
        # use the raw scores before softmax under the assumption they
        # were trained as linear SVMs
        scores = net.blobs['cls_score'].data
    else:
        # use softmax estimated probabilities
        scores = blobs_out['cls_prob']

    if cfg.TEST.BBOX_REG:
        # Apply bounding-box regression deltas
        box_deltas = blobs_out['bbox_pred']
        pred_boxes = bbox_transform_inv(boxes, box_deltas)
        pred_boxes = clip_boxes(pred_boxes, im.shape)
    else:
        # Simply repeat the boxes, once for each class
        pred_boxes = np.tile(boxes, (1, scores.shape[1]))

    if cfg.DEDUP_BOXES > 0 and not cfg.TEST.HAS_RPN:
        # Map scores and predictions back to the original set of boxes
        scores = scores[inv_index, :]
        pred_boxes = pred_boxes[inv_index, :]
        
    if 'attr_prob' in net.blobs:
        attr_scores = blobs_out['attr_prob']
    else:
        attr_scores = None
        
    if 'rel_prob' in net.blobs:
        rel_scores = blobs_out['rel_prob']
    else:
        rel_scores = None

    return scores, pred_boxes, attr_scores, rel_scores
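
# A minimal usage sketch for im_detect (assumes `net` is a caffe.Net loaded in
# TEST phase and `im` comes from cv2.imread, i.e. BGR order; the variable
# names are placeholders):
#
#   scores, boxes, attr_scores, rel_scores = im_detect(net, im)
#   cls = scores[0, 1:].argmax() + 1           # best non-background class, RoI 0
#   cls_box = boxes[0, 4 * cls:4 * (cls + 1)]  # its regressed (x1, y1, x2, y2)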

def vis_detections(im, class_name, dets, thresh=0.3, filename='vis.png'):
    """Visual debugging of detections."""
    import matplotlib.pyplot as plt
    im = im[:, :, (2, 1, 0)]  # BGR -> RGB for matplotlib
    plt.cla()
    plt.imshow(im)
    for i in xrange(np.minimum(10, dets.shape[0])):
        bbox = dets[i, :4]
        score = dets[i, -1]
        if score > thresh:
            plt.gca().add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='g', linewidth=3)
                )
            plt.title('{}  {:.3f}'.format(class_name, score))
    # Save before show(), since show() can clear the current figure.
    plt.savefig('./data/vis/%s' % filename)
    plt.show()
    
def vis_multiple(im, class_names, all_boxes, filename='vis.png'):
    """Visual debugging of detections."""
    
    print filename
    import matplotlib.pyplot as plt
    im = im[:, :, (2, 1, 0)]
    plt.cla()
    plt.imshow(im)
    
    max_boxes = 10
    image_scores = np.hstack([all_boxes[j][:, 4]
          for j in xrange(1, len(class_names))])
    if len(image_scores) > max_boxes:
        image_thresh = np.sort(image_scores)[-max_boxes]
    else:
        image_thresh = -np.inf
    for j in xrange(1, len(class_names)):
        keep = np.where(all_boxes[j][:, 4] >= image_thresh)[0]
        dets = all_boxes[j][keep, :]
        for i in range(dets.shape[0]):
            bbox = dets[i, :4]
            score = dets[i, -1]
            plt.gca().add_patch(
                plt.Rectangle((bbox[0], bbox[1]),
                              bbox[2] - bbox[0],
                              bbox[3] - bbox[1], fill=False,
                              edgecolor='red', linewidth=1)
                    )
                    
            plt.gca().text(bbox[0], bbox[1] - 2,
                        '{:s} {:.3f}'.format(class_names[j], score),
                        bbox=dict(facecolor='blue', alpha=0.5),
                        fontsize=8, color='white')        
                
    plt.title('Best %d Attributes using gt boxes' % max_boxes)
    plt.show()
    plt.savefig('./data/vis/%s' % filename)

def vis_relations(im, class_names, box_proposals, scores, filename='vis.png'):
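    """Visual debugging of relation detections.

    Arguments:
        im (ndarray): color image (in BGR order)
        class_names (list): relation class names (index 0 is 'no relation')
        box_proposals (ndarray): n x 4 matrix of boxes
        scores (ndarray): (n*n) x K matrix of relation scores, where row
            i*n + j scores the relation from box i to box j
    """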

    n = box_proposals.shape[0]
    assert scores.shape[0] == n*n
    print filename
    import matplotlib.pyplot as plt
    im = im[:, :, (2, 1, 0)]
    plt.cla()
    plt.imshow(im)
    
    max_rels = 5
    scores = scores[:, 1:]
    image_scores = scores.flatten()
    if len(image_scores) > max_rels:
        image_thresh = np.sort(image_scores)[-max_rels]
    else:
        image_thresh = -np.inf
        
    for i in xrange(n):
        for j in xrange(n):
            keep = np.where(scores[i*n+j] >= image_thresh)[0]
            for ix in keep:
                bbox = box_proposals[i]
                score = scores[i*n+j, ix]
                plt.gca().add_patch(
                    plt.Rectangle((bbox[0], bbox[1]),
                                  bbox[2] - bbox[0],
                                  bbox[3] - bbox[1], fill=False,
                                  edgecolor='red', linewidth=1)
                        )
                        
                plt.gca().text(bbox[0], bbox[1] - 2,
                            '{:s} {:.3f}'.format(class_names[ix + 1], score),
                            bbox=dict(facecolor='blue', alpha=0.5),
                            fontsize=8, color='white')   

                bbox = box_proposals[j]
                plt.gca().add_patch(
                    plt.Rectangle((bbox[0], bbox[1]),
                                  bbox[2] - bbox[0],
                                  bbox[3] - bbox[1], fill=False,
                                  edgecolor='red', linewidth=1)
                        )
                
    plt.title('Best %d Relations using gt boxes' % max_rels)
    plt.show()
    plt.savefig('./data/vis/%s' % filename)
    

def apply_nms(all_boxes, thresh):
    """Apply non-maximum suppression to all predicted boxes output by the
    test_net method.
    """
    num_classes = len(all_boxes)
    num_images = len(all_boxes[0])
    nms_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(num_classes)]
    for cls_ind in xrange(num_classes):
        for im_ind in xrange(num_images):
            dets = all_boxes[cls_ind][im_ind]
            if len(dets) == 0:
                continue
            # CPU NMS is much faster than GPU NMS when the number of boxes
            # is relative small (e.g., < 10k)
            # TODO(rbg): autotune NMS dispatch
            keep = nms(dets, thresh, force_cpu=True)
            if len(keep) == 0:
                continue
            nms_boxes[cls_ind][im_ind] = dets[keep, :].copy()
    return nms_boxes
    

def test_net(net, imdb, max_per_image=400, thresh=-np.inf, vis=False, load_cache=False):
    """Test a Fast R-CNN network on an image database."""
    num_images = len(imdb.image_index)
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(imdb.num_classes)]             

    output_dir = get_output_dir(imdb, net)
    det_file = os.path.join(output_dir, 'detections.pkl')
    if load_cache and os.path.exists(det_file):
        print 'Loading pickled detections from %s' % det_file
        with open(det_file, 'rb') as f:
            all_boxes = cPickle.load(f)
    
    else:
        # timers
        _t = {'im_detect' : Timer(), 'misc' : Timer()}

        if not cfg.TEST.HAS_RPN:
            roidb = imdb.roidb

        for i in xrange(num_images):
            # filter out any ground truth boxes
            if cfg.TEST.HAS_RPN:
                box_proposals = None
            else:
                # The roidb may contain ground-truth rois (for example, if the
                # roidb comes from the training or val split). We only want to
                # evaluate detection on the *non*-ground-truth rois. We select
                # only those rois whose gt_classes field is 0, meaning they have
                # no associated ground truth.
                box_proposals = roidb[i]['boxes'][roidb[i]['gt_classes'] == 0]
                
            im = cv2.imread(imdb.image_path_at(i))
            _t['im_detect'].tic()
            scores, boxes, attr_scores, rel_scores = im_detect(net, im, box_proposals)
            _t['im_detect'].toc()

            _t['misc'].tic()
            # skip j = 0, because it's the background class
            for j in xrange(1, imdb.num_classes):
                inds = np.where(scores[:, j] > thresh)[0]
                cls_scores = scores[inds, j]
                if cfg.TEST.AGNOSTIC:
                    cls_boxes = boxes[inds, 4:8]
                else:
                    cls_boxes = boxes[inds, j*4:(j+1)*4]
                cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
                          .astype(np.float32, copy=False)
                # Alternatively, apply soft-NMS:
                #   keep = soft_nms(cls_dets, method=cfg.TEST.SOFT_NMS)
                keep = nms(cls_dets, cfg.TEST.NMS)
                cls_dets = cls_dets[keep, :]
                if vis:
                    vis_detections(im, imdb.classes[j], cls_dets)
                all_boxes[j][i] = cls_dets

            # Limit to max_per_image detections *over all classes*
            if max_per_image > 0:
                image_scores = np.hstack([all_boxes[j][i][:, 4]
                                          for j in xrange(1, imdb.num_classes)])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in xrange(1, imdb.num_classes):
                        keep = np.where(all_boxes[j][i][:, 4] >= image_thresh)[0]
                        all_boxes[j][i] = all_boxes[j][i][keep, :]                        
                        
            _t['misc'].toc()

            print 'im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
                  .format(i + 1, num_images, _t['im_detect'].average_time,
                          _t['misc'].average_time)

        with open(det_file, 'wb') as f:
            cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)

    print 'Evaluating detections'
    imdb.evaluate_detections(all_boxes, output_dir)
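
# A sketch of a typical driver for test_net (modelled on the tools/test_net.py
# pattern used with this codebase; `prototxt`, `caffemodel` and `imdb` are
# placeholders the caller must supply):
#
#   caffe.set_mode_gpu()
#   net = caffe.Net(prototxt, caffemodel, caffe.TEST)
#   test_net(net, imdb, max_per_image=100)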
    
    
def test_net_with_gt_boxes(net, imdb, max_per_image=400, thresh=-np.inf, vis=False, load_cache=False):
    """Test a Fast R-CNN network on an image database, evaluating attribute 
       and relation detections given ground truth boxes."""
    num_images = len(imdb.image_index)
    # all detections are collected into:
    #    all_boxes[cls][image] = N x 5 array of detections in
    #    (x1, y1, x2, y2, score)
    all_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(imdb.num_attributes)]
    rel_boxes = [[[] for _ in xrange(num_images)]
                 for _ in xrange(imdb.num_relations)]                  

    output_dir = get_output_dir(imdb, net, attributes=True)
    det_file = os.path.join(output_dir, 'attribute_detections.pkl')
    rel_file = os.path.join(output_dir, 'relation_detections.pkl')
    if load_cache and os.path.exists(det_file):
        print 'Loading pickled detections from %s' % det_file
        with open(det_file, 'rb') as f:
            all_boxes = cPickle.load(f)
        with open(rel_file, 'rb') as f:
            rel_boxes = cPickle.load(f)
    
    else:
        # timers
        _t = {'im_detect' : Timer(), 'misc' : Timer()}

        roidb = imdb.gt_roidb()

        for i in xrange(num_images):
            box_proposals = roidb[i]['boxes']
                
            im = cv2.imread(imdb.image_path_at(i))
            _t['im_detect'].tic()
            scores, boxes, attr_scores, rel_scores = im_detect(net, im, box_proposals, force_boxes=True)
            _t['im_detect'].toc()

            _t['misc'].tic()
            # Some networks drop the leading 'no attribute' / 'no relation'
            # column; pad with zeros so column indices line up with the imdb.
            if attr_scores.shape[1] < imdb.num_attributes:
                attr_scores = np.hstack(
                    (np.zeros((attr_scores.shape[0], 1)), attr_scores))
            if rel_scores is not None and rel_scores.shape[1] < imdb.num_relations:
                rel_scores = np.hstack(
                    (np.zeros((rel_scores.shape[0], 1)), rel_scores))
            # skip j = 0, because it's the 'no attribute' class
            for j in xrange(1, imdb.num_attributes):
                inds = np.where(attr_scores[:, j] > thresh)[0]
                cls_scores = attr_scores[inds, j]
                cls_boxes = box_proposals[inds, :]
                cls_dets = np.hstack((cls_boxes, cls_scores[:, np.newaxis])) \
                          .astype(np.float32, copy=False)
                all_boxes[j][i] = cls_dets
                
            # Limit to max_per_image detections *over all attributes*
            if max_per_image > 0:
                image_scores = np.hstack([all_boxes[j][i][:, 4]
                                          for j in xrange(1, imdb.num_attributes)])
                if len(image_scores) > max_per_image:
                    image_thresh = np.sort(image_scores)[-max_per_image]
                    for j in xrange(1, imdb.num_attributes):
                        keep = np.where(all_boxes[j][i][:, 4] >= image_thresh)[0]
                        all_boxes[j][i] = all_boxes[j][i][keep, :]

            if vis:
                im_boxes = [all_boxes[j][i] for j in xrange(imdb.num_attributes)]
                vis_multiple(im, imdb.attributes, im_boxes, filename='attr_%d.png' % i)
                if rel_scores is not None:
                    vis_relations(im, imdb.relations, box_proposals, rel_scores, filename='rel_%d.png' % i)                
                        
            _t['misc'].toc()

            print 'im_detect: {:d}/{:d} {:.3f}s {:.3f}s' \
                  .format(i + 1, num_images, _t['im_detect'].average_time,
                          _t['misc'].average_time)

        with open(det_file, 'wb') as f:
            cPickle.dump(all_boxes, f, cPickle.HIGHEST_PROTOCOL)
        # Also cache relation detections so the load_cache branch above can
        # read them back.
        with open(rel_file, 'wb') as f:
            cPickle.dump(rel_boxes, f, cPickle.HIGHEST_PROTOCOL)

    print 'Evaluating attribute and/or relation detections'
    imdb.evaluate_attributes(all_boxes, output_dir)    
